{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Processing images and text with VLMs \n",
    "\n",
    "This notebook demonstrates how to utilize the `HuggingFaceTB/SmolVLM-Instruct` 4bit-quantized model for various multimodal tasks such as:\n",
    "- Visual Question Answering (VQA): Answering questions based on image content.\n",
    "- Text Recognition (OCR): Extracting and interpreting text in images.\n",
    "- Video Understanding: Describing videos through sequential frame analysis.\n",
    "\n",
    "By structuring prompts effectively, you can leverage the model for many applications, such as scene understanding, document analysis, and dynamic visual reasoning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the requirements in Google Colab\n",
    "# !pip install transformers datasets trl huggingface_hub bitsandbytes\n",
    "\n",
    "# Authenticate to Hugging Face\n",
    "from huggingface_hub import notebook_login\n",
    "notebook_login()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`low_cpu_mem_usage` was None, now default to True since model is quantized.\n",
      "You shouldn't move a model that is dispatched using accelerate hooks.\n",
      "Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'longest_edge': 1536}\n"
     ]
    }
   ],
   "source": [
    "import torch, PIL\n",
    "from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig\n",
    "from transformers.image_utils import load_image\n",
    "\n",
    "device = (\n",
    "    \"cuda\"\n",
    "    if torch.cuda.is_available()\n",
    "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    ")\n",
    "\n",
    "quantization_config = BitsAndBytesConfig(load_in_4bit=True)\n",
    "model_name = \"HuggingFaceTB/SmolVLM-Instruct\"\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_name,\n",
    "    quantization_config=quantization_config,\n",
    ").to(device)\n",
    "processor = AutoProcessor.from_pretrained(\"HuggingFaceTB/SmolVLM-Instruct\")\n",
    "\n",
    "print(processor.image_processor.size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Processing Images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with generating captions and answering questions about an image. We'll also explore processing multiple images.\n",
    "### 1. Single Image Captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "image_url1 = \"https://cdn.pixabay.com/photo/2024/11/20/09/14/christmas-9210799_1280.jpg\"\n",
    "display(Image(url=image_url1))\n",
    "\n",
    "image_url2 = \"https://cdn.pixabay.com/photo/2024/11/23/08/18/christmas-9218404_1280.jpg\"\n",
    "display(Image(url=image_url2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/duydl/Miniconda3/envs/py310/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:451: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>Can you describe the image?\\nAssistant: The image is a scene of a person walking in a forest. The person is wearing a coat and a cap. The person is holding the hand of another person. The person is walking on a path. The path is covered with dry leaves. The background of the image is a forest with trees.']\n"
     ]
    }
   ],
   "source": [
    "# Load  one image\n",
    "image1 = load_image(image_url1)\n",
    "\n",
    "# Create input messages\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"Can you describe the image?\"}\n",
    "        ]\n",
    "    },\n",
    "]\n",
    "\n",
    "# Prepare inputs\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[image1], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# Generate outputs\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Comparing Multiple Images\n",
    "The model can process and compare multiple images. Let's determine the common theme between two images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>What event do they both represent?\\nAssistant: Christmas.']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Load images\n",
    "image2 = load_image(image_url2)\n",
    "\n",
    "# Create input messages\n",
    "messages = [\n",
    "    # {\n",
    "    #     \"role\": \"user\",\n",
    "    #     \"content\": [\n",
    "    #         {\"type\": \"image\"},\n",
    "    #         {\"type\": \"image\"},\n",
    "    #         {\"type\": \"text\", \"text\": \"Can you describe the two images?\"}\n",
    "    #     ]\n",
    "    # },\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"What event do they both represent?\"}\n",
    "        ]\n",
    "    },\n",
    "]\n",
    "\n",
    "# Prepare inputs\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[image1, image2], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# Generate outputs\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 🔠 Text Recognition (OCR)\n",
    "VLM can also recognize and interpret text in images, making it suitable for tasks like document analysis.\n",
    "You could try experimenting on images with denser text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['User:<image>What is written?\\nAssistant: MERRY CHRISTMAS AND A HAPPY NEW YEAR']\n"
     ]
    }
   ],
   "source": [
    "document_image_url = \"https://cdn.pixabay.com/photo/2020/11/30/19/23/christmas-5792015_960_720.png\"\n",
    "display(Image(url=document_image_url))\n",
    "\n",
    "# Load the document image\n",
    "document_image = load_image(document_image_url)\n",
    "\n",
    "# Create input message for analysis\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"image\"},\n",
    "            {\"type\": \"text\", \"text\": \"What is written?\"}\n",
    "        ]\n",
    "    }\n",
    "]\n",
    "\n",
    "# Prepare inputs\n",
    "prompt = processor.apply_chat_template(messages, add_generation_prompt=True)\n",
    "inputs = processor(text=prompt, images=[document_image], return_tensors=\"pt\")\n",
    "inputs = inputs.to(device)\n",
    "\n",
    "# Generate outputs\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
    "generated_texts = processor.batch_decode(\n",
    "    generated_ids,\n",
    "    skip_special_tokens=True,\n",
    ")\n",
    "\n",
    "print(generated_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Processing videos\n",
    "\n",
    "Visual-Language Models (VLMs) can process videos indirectly by extracting keyframes and reasoning over them in temporal order. While VLMs lack the temporal modeling capabilities of dedicated video models, they can still:\n",
    "- Describe actions or events by analyzing sampled frames sequentially.\n",
    "- Answer questions about videos based on representative keyframes.\n",
    "- Summarize video content by combining textual descriptions of multiple frames.\n",
    "\n",
    "Let experiment on an example:\n",
    "\n",
    "<video width=\"640\" height=\"360\" controls>\n",
    "  <source src=\"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\" type=\"video/mp4\">\n",
    "  Your browser does not support the video tag.\n",
    "</video>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install opencv-python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "def extract_frames(video_path, max_frames=50, target_size=None):\n",
    "    cap = cv2.VideoCapture(video_path)\n",
    "    if not cap.isOpened():\n",
    "        raise ValueError(f\"Could not open video: {video_path}\")\n",
    "    \n",
    "    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "    frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)\n",
    "\n",
    "    frames = []\n",
    "    for idx in frame_indices:\n",
    "        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n",
    "        ret, frame = cap.read()\n",
    "        if ret:\n",
    "            frame = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
    "            if target_size:\n",
    "                frames.append(resize_and_crop(frame, target_size))\n",
    "            else:\n",
    "                frames.append(frame)\n",
    "    cap.release()\n",
    "    return frames\n",
    "\n",
    "def resize_and_crop(image, target_size):\n",
    "    width, height = image.size\n",
    "    scale = target_size / min(width, height)\n",
    "    image = image.resize((int(width * scale), int(height * scale)), PIL.Image.Resampling.LANCZOS)\n",
    "    left = (image.width - target_size) // 2\n",
    "    top = (image.height - target_size) // 2\n",
    "    return image.crop((left, top, left + target_size, top + target_size))\n",
    "\n",
    "# Video link\n",
    "video_link = \"https://cdn.pixabay.com/video/2023/10/28/186794-879050032_large.mp4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Response: User: Following are the frames of a video in temporal order.<image>Describe what the woman is doing.\n",
      "Assistant: The woman is hanging an ornament on a Christmas tree.\n"
     ]
    }
   ],
   "source": [
    "question = \"Describe what the woman is doing.\"\n",
    "\n",
    "def generate_response(model, processor, frames, question):\n",
    "\n",
    "    image_tokens = [{\"type\": \"image\"} for _ in frames]\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [{\"type\": \"text\", \"text\": \"Following are the frames of a video in temporal order.\"}, *image_tokens, {\"type\": \"text\", \"text\": question}]\n",
    "        }\n",
    "    ]\n",
    "    inputs = processor(\n",
    "        text=processor.apply_chat_template(messages, add_generation_prompt=True),\n",
    "        images=frames,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(model.device)\n",
    "\n",
    "    outputs = model.generate(\n",
    "        **inputs, max_new_tokens=100, num_beams=5, temperature=0.7, do_sample=True, use_cache=True\n",
    "    )\n",
    "    return processor.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "# Extract frames from the video\n",
    "frames = extract_frames(video_link, max_frames=15, target_size=384)\n",
    "\n",
    "processor.image_processor.size = (384, 384)\n",
    "processor.image_processor.do_resize = False\n",
    "# Generate response\n",
    "response = generate_response(model, processor, frames, question)\n",
    "\n",
    "# Display the result\n",
    "# print(\"Question:\", question)\n",
    "print(\"Response:\", response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 💐 You're Done!\n",
    "\n",
    "This notebook demonstrated how to use a Visual-Language Model (VLM) such as formating prompts for multimodal tasks. By following the steps outlined here, you can experiment with VLMs and their applications.\n",
    "\n",
    "### Next Steps to Explore:\n",
    "- Experiment with more use cases of VLMs.\n",
    "- Collaborate with a colleague by reviewing their pull requests (PRs).\n",
    "- Contribute to improving this course material by opening an issue or submitting a PR to introduce new use cases, examples, or concepts.\n",
    "\n",
    "Happy exploring! 🌟"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
